import argparse
import pandas as pd
from transformers import pipeline
import torch
from datasets import Dataset
import warnings

def get_labels(batch, classes, template, classifier):
    """Score one batch of texts against every candidate class.

    Designed to be used as the function of
    ``datasets.Dataset.map(..., batched=True)``.

    Args:
        batch: Mapping with a ``"text"`` key holding a list of input strings.
        classes: Candidate labels passed to the classifier.
        template: Hypothesis template with one ``{}`` placeholder. It is passed
            to the classifier and is also used to build the output keys
            (``template.format(label)``).
        classifier: Callable with the zero-shot pipeline calling convention
            ``classifier(texts, classes, hypothesis_template=..., multi_label=...)``
            that returns, per text, a dict with ``"labels"`` and ``"scores"``.

    Returns:
        Dict mapping ``template.format(label)`` to a list of scores, one per
        text in the batch, in batch order.
    """
    texts = list(batch["text"])
    results = {template.format(label): [] for label in classes}
    if not texts:
        # Nothing to score; don't hand the pipeline an empty input.
        return results

    outputs = classifier(texts, classes, hypothesis_template=template, multi_label=True)
    # Defensive: a pipeline may return a bare dict (not a list) for a single
    # input. Iterating that dict would yield its key strings, not results.
    if isinstance(outputs, dict):
        outputs = [outputs]

    for output in outputs:
        # Labels are matched by name, so their order in the output does not matter.
        for label, score in zip(output["labels"], output["scores"]):
            results[template.format(label)].append(score)

    return results

def main(inp_file, out_file='dc_weekly_with_hero_and_villain_scores.csv', *, require_gpu=True):
    """Add zero-shot "hero" and "villain" score columns to a CSV of texts.

    Reads ``inp_file`` (which must have a ``clean_text`` column) and runs the
    ``facebook/bart-large-mnli`` zero-shot pipeline with ``multi_label=True``
    for two label sets. Each label set adds one column per label, named by the
    formatted hypothesis, e.g. ``"Donald Trump is good."``.

    After the hero pass, a checkpoint with the original columns plus the hero
    scores is written to ``f"{inp_file}.bak"``. The final table (hero and
    villain scores) is written to ``out_file``.

    Args:
        inp_file: Path to the input CSV.
        out_file: Path for the final CSV. Defaults to the historical output name.
        require_gpu: If True (the default), refuse to run without CUDA.

    Raises:
        RuntimeError: If ``require_gpu`` is True and CUDA is unavailable.
    """
    df = pd.read_csv(inp_file)

    # Pipeline device convention: 0 = first GPU, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    if require_gpu and device != 0:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise RuntimeError("CUDA is not available; pass require_gpu=False to run on CPU.")

    classifier = pipeline("zero-shot-classification",
                          model="facebook/bart-large-mnli",
                          device=device)

    # Alternative label sets, currently disabled:
    # topics = ['Israel',
    #           'Ukraine',
    #           'taxation',
    #           'government spending',
    #           'government corruption',
    #           'crime',
    #           'socialism',
    #           'immigration',
    #           'racism',
    #           'abortion',
    #           'guns',
    #           'religion',
    #           'the Middle East',
    #           'Russia',
    #           'China',
    #          ]
    # topics_template = 'The text mentions {}.'
    #
    # kind_of_news = ['US domestic news',
    #                 'international news']
    # kind_of_news_template = 'This is {}.'
    #
    # complains_against = ['The Democratic Party',
    #                      'The Republican Party',
    #                      'The US Federal government',
    #                      'Joe Biden',
    #                      'Donald Trump']
    # complains_against_template = '{} is bad.'
    #
    # tone_of_news = ['conspiracy theory',
    #                 'fake news article',
    #                 'legitimate news article']
    # tone_of_news_template = 'This is a {}.'

    heroes = ['The Republican party', 'Donald Trump', 'DeSantis', 'Russia', 'RFK jr']
    heroes_template = '{} is good.'

    villains = ['The Democratic party', 'Joe Biden', 'The war in Ukraine', 'Big business', 'The pharmaceutical industry']
    villains_template = '{} is bad.'

    # A RangeIndex keeps df row-aligned with the 0..n-1 index that
    # Dataset.to_pandas() produces, so the joins below line up.
    df = df.reset_index(drop=True)
    # NOTE(review): missing values in clean_text are likely to fail when
    # building the Arrow dataset; confirm the input has been cleaned.
    dataset = Dataset.from_dict({"text": df.clean_text.values})

    # Suppress a specific UserWarning from the DataLoader.
    warnings.filterwarnings("ignore", message="Length of IterableDataset")

    df = df.join(_score_dataset(dataset, heroes, heroes_template, classifier))
    # Checkpoint so the hero scores survive a failure in the second pass.
    df.to_csv(f'{inp_file}.bak', index=False)

    df = df.join(_score_dataset(dataset, villains, villains_template, classifier))
    df.to_csv(out_file, index=False)


def _score_dataset(dataset, classes, template, classifier):
    """Score every text in ``dataset`` and return only the score columns as a DataFrame."""
    scored = dataset.map(lambda batch: get_labels(batch, classes, template, classifier), batched=True)
    return scored.to_pandas().drop(['text'], axis=1)

if __name__ == '__main__':
    # Command-line entry point: the single positional argument is the CSV to score.
    cli = argparse.ArgumentParser(description="Get scores from NLI model as ZSC.")
    cli.add_argument('inp_file', type=str, help='Input file loc')
    main(cli.parse_args().inp_file)
